package com.zhoug.player.exoplayer.v2;

import java.util.Arrays;
import java.util.Collections;

import androidx.annotation.Nullable;
import androidx.media3.common.C;
import androidx.media3.common.Format;
import androidx.media3.common.MimeTypes;
import androidx.media3.common.ParserException;
import androidx.media3.common.util.Assertions;
import androidx.media3.common.util.Log;
import androidx.media3.common.util.ParsableBitArray;
import androidx.media3.common.util.ParsableByteArray;
import androidx.media3.common.util.UnstableApi;
import androidx.media3.common.util.Util;
import androidx.media3.extractor.AacUtil;
import androidx.media3.extractor.DiscardingTrackOutput;
import androidx.media3.extractor.ExtractorOutput;
import androidx.media3.extractor.TrackOutput;
import androidx.media3.extractor.ts.ElementaryStreamReader;
import androidx.media3.extractor.ts.TsPayloadReader;

import static androidx.media3.common.util.Assertions.checkState;
import static java.lang.Math.min;

/**
 * Reads ADTS-framed AAC audio, and optionally ID3 tags interleaved between ADTS frames, from
 * elementary stream data.
 *
 * <p>Each ADTS frame is written to the audio {@link TrackOutput} as one key-frame sample. ID3 tags
 * are written to a separate metadata track when {@code exposeId3} is true and discarded otherwise.
 *
 * @author: zhoug
 * @since: 2025-07-23
 */
@UnstableApi
public class FixedAdtsReader implements ElementaryStreamReader {
    // NOTE(review): DEFAULT_SAMPLE_RATE, DEFAULT_CHANNEL_COUNT, getSampleRate and getChannelCount are
    // not referenced anywhere in this class; presumably they were meant as fallbacks for streams with
    // broken headers. Wire them up or delete them.
    private static final int DEFAULT_SAMPLE_RATE = 44100; // Default sample rate, in Hz.
    private static final int DEFAULT_CHANNEL_COUNT = 2;   // Default channel count.

    private static final String TAG = "AdtsReader";

    // Parser states driven by consume().
    private static final int STATE_FINDING_SAMPLE = 0;
    private static final int STATE_CHECKING_ADTS_HEADER = 1;
    private static final int STATE_READING_ID3_HEADER = 2;
    private static final int STATE_READING_ADTS_HEADER = 3;
    private static final int STATE_READING_SAMPLE = 4;

    // ADTS header bytes that follow the 2-byte sync word, and the optional CRC after them.
    private static final int HEADER_SIZE = 5;
    private static final int CRC_SIZE = 2;

    // Match states used while looking for the next sample. They are shifted left so that
    // "matchState | byte" is a unique switch key for every (state, byte) pair.
    private static final int MATCH_STATE_VALUE_SHIFT = 8;
    private static final int MATCH_STATE_START = 1 << MATCH_STATE_VALUE_SHIFT;
    private static final int MATCH_STATE_FF = 2 << MATCH_STATE_VALUE_SHIFT;
    private static final int MATCH_STATE_I = 3 << MATCH_STATE_VALUE_SHIFT;
    private static final int MATCH_STATE_ID = 4 << MATCH_STATE_VALUE_SHIFT;

    private static final int ID3_HEADER_SIZE = 10;
    private static final int ID3_SIZE_OFFSET = 6;
    private static final byte[] ID3_IDENTIFIER = {'I', 'D', '3'};
    private static final int VERSION_UNSET = -1;

    private final boolean exposeId3;
    private final ParsableBitArray adtsScratch;
    private final ParsableByteArray id3HeaderBuffer;
    @Nullable
    private final String language;
    private final @C.RoleFlags int roleFlags;

    private String formatId;
    private TrackOutput output;
    private TrackOutput id3Output;

    private int state;
    private int bytesRead;

    private int matchState;

    private boolean hasCrc;
    private boolean foundFirstFrame;

    // Values from the first accepted frame, used to validate later sync word candidates.
    private int firstFrameVersion;
    private int firstFrameSampleRateIndex;

    private int currentFrameVersion;

    // Used when parsing the header.
    private boolean hasOutputFormat;
    private long sampleDurationUs;
    private int sampleSize;

    // Used when reading the samples.
    private long timeUs;

    private TrackOutput currentOutput;
    private long currentSampleDuration;

    /**
     * Creates a reader with no track language and no role flags.
     *
     * @param exposeId3 True if the reader should expose ID3 information.
     */
    public FixedAdtsReader(boolean exposeId3) {
        this(exposeId3, null, /* roleFlags= */ 0);
    }

    /**
     * @param exposeId3 True if the reader should expose ID3 information.
     * @param language  Track language.
     * @param roleFlags Track role flags.
     */
    public FixedAdtsReader(boolean exposeId3, @Nullable String language, @C.RoleFlags int roleFlags) {
        adtsScratch = new ParsableBitArray(new byte[HEADER_SIZE + CRC_SIZE]);
        // The buffer is pre-filled with "ID3" because the identifier is consumed while searching.
        id3HeaderBuffer = new ParsableByteArray(Arrays.copyOf(ID3_IDENTIFIER, ID3_HEADER_SIZE));
        setFindingSampleState();
        firstFrameVersion = VERSION_UNSET;
        firstFrameSampleRateIndex = C.INDEX_UNSET;
        sampleDurationUs = C.TIME_UNSET;
        timeUs = C.TIME_UNSET;
        this.exposeId3 = exposeId3;
        this.language = language;
        this.roleFlags = roleFlags;
    }

    /**
     * Returns whether an integer matches an ADTS SYNC word: twelve 1-bits followed by a zero layer
     * field. The MPEG version bit and the protection_absent bit are ignored.
     */
    public static boolean isAdtsSyncWord(int candidateSyncWord) {
        return (candidateSyncWord & 0xFFF6) == 0xFFF0;
    }

    @Override
    public void seek() {
        timeUs = C.TIME_UNSET;
        resetSync();
    }

    @Override
    public void createTracks(ExtractorOutput extractorOutput, TsPayloadReader.TrackIdGenerator idGenerator) {
        idGenerator.generateNewId();
        formatId = idGenerator.getFormatId();
        output = extractorOutput.track(idGenerator.getTrackId(), C.TRACK_TYPE_AUDIO);
        currentOutput = output;
        if (exposeId3) {
            idGenerator.generateNewId();
            id3Output = extractorOutput.track(idGenerator.getTrackId(), C.TRACK_TYPE_METADATA);
            id3Output.format(
                    new Format.Builder()
                            .setId(idGenerator.getFormatId())
                            .setSampleMimeType(MimeTypes.APPLICATION_ID3)
                            .build());
        } else {
            id3Output = new DiscardingTrackOutput();
        }
    }

    /**
     * Records the timestamp of the packet that is about to be consumed.
     *
     * <p>If {@code pesTimeUs} is {@link C#TIME_UNSET} (presumably a PES packet whose header carried
     * no PTS), the running timestamp is kept. {@link #readSample} keeps advancing it by one frame
     * duration per sample. Clearing it here would make the next completed sample fail the
     * {@code timeUs != C.TIME_UNSET} check in {@link #readSample}.
     */
    @Override
    public void packetStarted(long pesTimeUs, @TsPayloadReader.Flags int flags) {
        if (pesTimeUs != C.TIME_UNSET) {
            timeUs = pesTimeUs;
        }
    }

    /**
     * Feeds elementary stream bytes through the parser state machine until {@code data} is
     * exhausted. Partial headers and samples are carried over to the next call.
     *
     * @throws ParserException If the first ADTS header describes an unsupported AAC configuration.
     */
    @Override
    public void consume(ParsableByteArray data) throws ParserException {
        assertTracksCreated();
        while (data.bytesLeft() > 0) {
            switch (state) {
                case STATE_FINDING_SAMPLE:
                    findNextSample(data);
                    break;
                case STATE_READING_ID3_HEADER:
                    if (continueRead(data, id3HeaderBuffer.getData(), ID3_HEADER_SIZE)) {
                        parseId3Header();
                    }
                    break;
                case STATE_CHECKING_ADTS_HEADER:
                    checkAdtsHeader(data);
                    break;
                case STATE_READING_ADTS_HEADER:
                    int targetLength = hasCrc ? HEADER_SIZE + CRC_SIZE : HEADER_SIZE;
                    if (continueRead(data, adtsScratch.data, targetLength)) {
                        parseAdtsHeader();
                    }
                    break;
                case STATE_READING_SAMPLE:
                    readSample(data);
                    break;
                default:
                    throw new IllegalStateException();
            }
        }
    }

    /**
     * Maps an ADTS sampling_frequency_index to a sample rate in Hz, or returns
     * {@link Format#NO_VALUE} for the reserved/escape indices (13-15).
     */
    private int getSampleRate(int index) {
        return switch (index) {
            case 0 -> 96000;
            case 1 -> 88200;
            case 2 -> 64000;
            case 3 -> 48000;
            case 4 -> 44100;
            case 5 -> 32000;
            case 6 -> 24000;
            case 7 -> 22050;
            case 8 -> 16000;
            case 9 -> 12000;
            case 10 -> 11025;
            case 11 -> 8000;
            case 12 -> 7350;
            default -> Format.NO_VALUE;
        };
    }

    /**
     * Maps an ADTS channel_configuration to a channel count. Configuration 7 is 7.1 (8 channels);
     * every other value is returned unchanged.
     */
    private int getChannelCount(int config) {
        return config == 7 ? 8 : config;
    }

    @Override
    public void packetFinished(boolean isEndOfInput) {
        // Do nothing.
    }

    /**
     * Returns the duration in microseconds per sample, or {@link C#TIME_UNSET} if the sample duration
     * is not available.
     */
    public long getSampleDurationUs() {
        return sampleDurationUs;
    }

    /** Forgets the current frame sync so that the next sync word candidate is fully validated. */
    private void resetSync() {
        foundFirstFrame = false;
        setFindingSampleState();
    }

    /**
     * Continues a read from the provided {@code source} into a given {@code target}. It's assumed
     * that the data should be written into {@code target} starting from an offset of zero.
     *
     * @param source       The source from which to read.
     * @param target       The target into which data is to be read.
     * @param targetLength The target length of the read.
     * @return Whether the target length was reached.
     */
    private boolean continueRead(ParsableByteArray source, byte[] target, int targetLength) {
        int bytesToRead = min(source.bytesLeft(), targetLength - bytesRead);
        source.readBytes(target, bytesRead, bytesToRead);
        bytesRead += bytesToRead;
        return bytesRead == targetLength;
    }

    /**
     * Sets the state to STATE_FINDING_SAMPLE.
     */
    private void setFindingSampleState() {
        state = STATE_FINDING_SAMPLE;
        bytesRead = 0;
        matchState = MATCH_STATE_START;
    }

    /**
     * Sets the state to STATE_READING_ID3_HEADER and resets the fields required for {@link
     * #parseId3Header()}.
     */
    private void setReadingId3HeaderState() {
        state = STATE_READING_ID3_HEADER;
        // The "ID3" identifier has already been consumed and is pre-filled in the buffer.
        bytesRead = ID3_IDENTIFIER.length;
        sampleSize = 0;
        id3HeaderBuffer.setPosition(0);
    }

    /**
     * Sets the state to STATE_READING_SAMPLE.
     *
     * @param outputToUse           TrackOutput object to write the sample to
     * @param currentSampleDuration Duration of the sample to be read
     * @param priorReadBytes        Size of prior read bytes
     * @param sampleSize            Size of the sample
     */
    private void setReadingSampleState(
            TrackOutput outputToUse, long currentSampleDuration, int priorReadBytes, int sampleSize) {
        state = STATE_READING_SAMPLE;
        bytesRead = priorReadBytes;
        this.currentOutput = outputToUse;
        this.currentSampleDuration = currentSampleDuration;
        this.sampleSize = sampleSize;
    }

    /**
     * Sets the state to STATE_READING_ADTS_HEADER.
     */
    private void setReadingAdtsHeaderState() {
        state = STATE_READING_ADTS_HEADER;
        bytesRead = 0;
    }

    /**
     * Sets the state to STATE_CHECKING_ADTS_HEADER.
     */
    private void setCheckingAdtsHeaderState() {
        state = STATE_CHECKING_ADTS_HEADER;
        bytesRead = 0;
    }

    /**
     * Locates the next sample start, advancing the position to the byte that immediately follows the
     * identifier. If a sample was not located, the position is advanced to the limit.
     *
     * <p>Two identifiers are recognized: an ADTS sync word (0xFF followed by a byte that completes
     * {@link #isAdtsSyncWord}) and the ID3 tag identifier "ID3". {@link #matchState} carries a
     * partial match across calls.
     *
     * @param pesBuffer The buffer whose position should be advanced.
     */
    private void findNextSample(ParsableByteArray pesBuffer) {
        byte[] adtsData = pesBuffer.getData();
        int position = pesBuffer.getPosition();
        int endOffset = pesBuffer.limit();
        while (position < endOffset) {
            int data = adtsData[position++] & 0xFF;
            if (matchState == MATCH_STATE_FF && isAdtsSyncBytes((byte) 0xFF, (byte) data)) {
                if (foundFirstFrame
                        || checkSyncPositionValid(pesBuffer, /* syncPositionCandidate= */ position - 2)) {
                    // Second sync byte: bit 3 is the MPEG version, bit 0 is protection_absent.
                    currentFrameVersion = (data & 0x8) >> 3;
                    hasCrc = (data & 0x1) == 0;
                    if (!foundFirstFrame) {
                        setCheckingAdtsHeaderState();
                    } else {
                        setReadingAdtsHeaderState();
                    }
                    pesBuffer.setPosition(position);
                    return;
                }
            }

            switch (matchState | data) {
                case MATCH_STATE_START | 0xFF:
                    matchState = MATCH_STATE_FF;
                    break;
                case MATCH_STATE_START | 'I':
                    matchState = MATCH_STATE_I;
                    break;
                case MATCH_STATE_I | 'D':
                    matchState = MATCH_STATE_ID;
                    break;
                case MATCH_STATE_ID | '3':
                    setReadingId3HeaderState();
                    pesBuffer.setPosition(position);
                    return;
                default:
                    if (matchState != MATCH_STATE_START) {
                        // If matching fails in a later state, revert to MATCH_STATE_START and
                        // check this byte again
                        matchState = MATCH_STATE_START;
                        position--;
                    }
                    break;
            }
        }
        pesBuffer.setPosition(position);
    }

    /**
     * Peeks the Adts header of the current frame and checks if it is valid. If the header is valid,
     * transition to {@link #STATE_READING_ADTS_HEADER}; else, transition to {@link
     * #STATE_FINDING_SAMPLE}.
     */
    private void checkAdtsHeader(ParsableByteArray buffer) {
        if (buffer.bytesLeft() == 0) {
            // Not enough data to check yet, defer this check.
            return;
        }
        // Peek the next byte of buffer into scratch array.
        adtsScratch.data[0] = buffer.getData()[buffer.getPosition()];

        // Bits 2-5 of the first header byte hold the sampling frequency index.
        adtsScratch.setPosition(2);
        int currentFrameSampleRateIndex = adtsScratch.readBits(4);
        if (firstFrameSampleRateIndex != C.INDEX_UNSET
                && currentFrameSampleRateIndex != firstFrameSampleRateIndex) {
            // Invalid header.
            resetSync();
            return;
        }

        if (!foundFirstFrame) {
            foundFirstFrame = true;
            firstFrameVersion = currentFrameVersion;
            firstFrameSampleRateIndex = currentFrameSampleRateIndex;
        }
        setReadingAdtsHeaderState();
    }

    /**
     * Checks whether a candidate SYNC word position is likely to be the position of a real SYNC word.
     * The caller must check that the first byte of the SYNC word is 0xFF before calling this method.
     * This method performs the following checks:
     *
     * <ul>
     *   <li>The MPEG version of this frame must match the previously detected version.
     *   <li>The sample rate index of this frame must match the previously detected sample rate index.
     *   <li>The frame size must be at least 7 bytes
     *   <li>The bytes following the frame must be either another SYNC word with the same MPEG
     *       version, or the start of an ID3 header.
     * </ul>
     * <p>
     * With the exception of the first check, if there is insufficient data in the buffer then checks
     * are optimistically skipped and {@code true} is returned.
     *
     * <p>Note that this method moves the position of {@code pesBuffer}.
     *
     * @param pesBuffer             The buffer containing the data to check.
     * @param syncPositionCandidate The candidate SYNC word position. May be -1 if the first byte of
     *                              the candidate was the last byte of the previously consumed buffer.
     * @return True if all checks were passed or skipped, indicating the position is likely to be the
     * position of a real SYNC word. False otherwise.
     */
    private boolean checkSyncPositionValid(ParsableByteArray pesBuffer, int syncPositionCandidate) {
        // The first sync byte is always 0xFF and carries no information, and it may lie in the
        // previous buffer (candidate == -1). Start at the second sync byte, which is always in this
        // buffer.
        pesBuffer.setPosition(syncPositionCandidate + 1);
        if (!tryRead(pesBuffer, adtsScratch.data, 1)) {
            return false;
        }

        // The MPEG version of this frame must match the previously detected version.
        adtsScratch.setPosition(4);
        int currentFrameVersion = adtsScratch.readBits(1);
        if (firstFrameVersion != VERSION_UNSET && currentFrameVersion != firstFrameVersion) {
            return false;
        }

        // The sample rate index of this frame must match the previously detected sample rate index.
        if (firstFrameSampleRateIndex != C.INDEX_UNSET) {
            if (!tryRead(pesBuffer, adtsScratch.data, 1)) {
                // Insufficient data for further checks.
                return true;
            }
            adtsScratch.setPosition(2);
            int currentFrameSampleRateIndex = adtsScratch.readBits(4);
            if (currentFrameSampleRateIndex != firstFrameSampleRateIndex) {
                return false;
            }
            // Rewind so the frame-length read below starts at the first header byte after sync.
            pesBuffer.setPosition(syncPositionCandidate + 2);
        }

        // The frame size must be at least 7 bytes.
        if (!tryRead(pesBuffer, adtsScratch.data, 4)) {
            // Insufficient data for further checks.
            return true;
        }
        // frame_length is the 13 bits that follow the first 14 bits of the header after sync.
        adtsScratch.setPosition(14);
        int frameSize = adtsScratch.readBits(13);
        if (frameSize < 7) {
            return false;
        }

        // The bytes following the frame must be either another SYNC word with the same MPEG version, or
        // the start of an ID3 header.
        byte[] data = pesBuffer.getData();
        int dataLimit = pesBuffer.limit();
        int nextSyncPosition = syncPositionCandidate + frameSize;
        if (nextSyncPosition >= dataLimit) {
            // Insufficient data for further checks.
            return true;
        }
        if (data[nextSyncPosition] == (byte) 0xFF) {
            if (nextSyncPosition + 1 == dataLimit) {
                // Insufficient data for further checks.
                return true;
            }
            return isAdtsSyncBytes((byte) 0xFF, data[nextSyncPosition + 1])
                    && ((data[nextSyncPosition + 1] & 0x8) >> 3) == currentFrameVersion;
        } else {
            if (data[nextSyncPosition] != 'I') {
                return false;
            }
            if (nextSyncPosition + 1 == dataLimit) {
                // Insufficient data for further checks.
                return true;
            }
            if (data[nextSyncPosition + 1] != 'D') {
                return false;
            }
            if (nextSyncPosition + 2 == dataLimit) {
                // Insufficient data for further checks.
                return true;
            }
            return data[nextSyncPosition + 2] == '3';
        }
    }

    /** Returns whether the two bytes, read big-endian, form an ADTS sync word. */
    private boolean isAdtsSyncBytes(byte firstByte, byte secondByte) {
        int syncWord = (firstByte & 0xFF) << 8 | (secondByte & 0xFF);
        return isAdtsSyncWord(syncWord);
    }

    /**
     * Reads {@code targetLength} bytes into target, and returns whether the read succeeded. Nothing
     * is consumed if fewer than {@code targetLength} bytes are available.
     */
    private boolean tryRead(ParsableByteArray source, byte[] target, int targetLength) {
        if (source.bytesLeft() < targetLength) {
            return false;
        }
        source.readBytes(target, /* offset= */ 0, targetLength);
        return true;
    }

    /**
     * Parses the Id3 header, forwards it to the ID3 output and switches to reading the tag body.
     */
    private void parseId3Header() {
        id3Output.sampleData(id3HeaderBuffer, ID3_HEADER_SIZE);
        id3HeaderBuffer.setPosition(ID3_SIZE_OFFSET);
        setReadingSampleState(
                id3Output, 0, ID3_HEADER_SIZE, id3HeaderBuffer.readSynchSafeInt() + ID3_HEADER_SIZE);
    }

    /**
     * Parses the ADTS header in {@link #adtsScratch}, outputs the track format on the first call,
     * and switches to reading the frame payload.
     *
     * <p>A header whose frame_length is smaller than the header itself is corrupt. That frame is
     * skipped and sync is re-acquired, instead of committing a sample with a negative size.
     *
     * @throws ParserException If the first header describes an AAC configuration that
     *     {@link AacUtil#parseAudioSpecificConfig} rejects.
     */
    private void parseAdtsHeader() throws ParserException {
        adtsScratch.setPosition(0);

        if (!hasOutputFormat) {
            int audioObjectType = adtsScratch.readBits(2) + 1;
            if (audioObjectType != 2) {
                // The stream indicates AAC-Main (1), AAC-SSR (3) or AAC-LTP (4). When the stream indicates
                // AAC-Main it's more likely that the stream contains HE-AAC (5), which cannot be
                // represented correctly in the 2 bit audio_object_type field in the ADTS header. In
                // practice when the stream indicates AAC-SSR or AAC-LTP it more commonly contains AAC-LC or
                // HE-AAC. Since most Android devices don't support AAC-Main, AAC-SSR or AAC-LTP, and since
                // indicating AAC-LC works for HE-AAC streams, we pretend that we're dealing with AAC-LC and
                // hope for the best. In practice this often works.
                // See: https://github.com/google/ExoPlayer/issues/774
                // See: https://github.com/google/ExoPlayer/issues/1383
                Log.w(TAG, "Detected audio object type: " + audioObjectType + ", but assuming AAC LC.");
                audioObjectType = 2;
            }

            // Skip sampling_frequency_index (4 bits, already captured) and private_bit (1 bit).
            adtsScratch.skipBits(5);
            int channelConfig = adtsScratch.readBits(3);

            byte[] audioSpecificConfig =
                    AacUtil.buildAudioSpecificConfig(
                            audioObjectType, firstFrameSampleRateIndex, channelConfig);
            AacUtil.Config aacConfig = AacUtil.parseAudioSpecificConfig(audioSpecificConfig);
            Format format =
                    new Format.Builder()
                            .setId(formatId)
                            .setSampleMimeType(MimeTypes.AUDIO_AAC)
                            .setCodecs(aacConfig.codecs)
                            .setChannelCount(aacConfig.channelCount)
                            .setSampleRate(aacConfig.sampleRateHz)
                            .setInitializationData(Collections.singletonList(audioSpecificConfig))
                            .setLanguage(language)
                            .setRoleFlags(roleFlags)
                            .build();
            // In this class a sample is an access unit, but the MediaFormat sample rate specifies the
            // number of PCM audio samples per second.
            sampleDurationUs = (C.MICROS_PER_SECOND * 1024) / format.sampleRate;
            output.format(format);
            hasOutputFormat = true;
        } else {
            // Skip profile, sampling_frequency_index, private_bit and channel_configuration.
            adtsScratch.skipBits(10);
        }

        // Skip original/copy, home, copyright_identification_bit and copyright_identification_start.
        adtsScratch.skipBits(4);
        int frameLength = adtsScratch.readBits(13);
        int sampleSize = frameLength - 2 /* the sync word */ - HEADER_SIZE;
        if (hasCrc) {
            sampleSize -= CRC_SIZE;
        }
        if (sampleSize < 0) {
            // frame_length counts the whole frame including its header, so this header is corrupt
            // (or a false sync match). Drop it and look for the next sync word.
            Log.w(TAG, "Invalid ADTS frame length: " + frameLength + ", skipping frame.");
            resetSync();
            return;
        }

        setReadingSampleState(output, sampleDurationUs, 0, sampleSize);
    }

    /**
     * Reads the rest of the sample and, once it is complete, commits it with the current timestamp.
     */
    private void readSample(ParsableByteArray data) {
        int bytesToRead = min(data.bytesLeft(), sampleSize - bytesRead);
        currentOutput.sampleData(data, bytesToRead);
        bytesRead += bytesToRead;
        if (bytesRead == sampleSize) {
            // packetStarted method must be called before reading samples.
            checkState(timeUs != C.TIME_UNSET);
            currentOutput.sampleMetadata(timeUs, C.BUFFER_FLAG_KEY_FRAME, sampleSize, 0, null);
            timeUs += currentSampleDuration;
            setFindingSampleState();
        }
    }

    /** Fails fast if {@link #createTracks} has not been called yet. */
    private void assertTracksCreated() {
        Assertions.checkNotNull(output);
        Util.castNonNull(currentOutput);
        Util.castNonNull(id3Output);
    }
}

